Predict the brain tumor region with MRI

The MRI data was preprocessed here 📒; it is a processed version of this dataset 💼

Description of the dataset

This brain tumor dataset contains 3064 T1-weighted contrast-enhanced images from 233 patients with three kinds of brain tumor: meningioma (708 slices), glioma (1426 slices), and pituitary tumor (930 slices). Due to the repository's file-size limit, the whole dataset is split into four subsets, archived in four .zip files of 766 slices each. The 5-fold cross-validation indices are also provided.
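The .mat files in this dataset are saved in MATLAB v7.3 format, which is HDF5 under the hood, so they can be opened with `h5py` rather than `scipy.io.loadmat`. A minimal sketch (writing and reading a tiny stand-in file with hypothetical values, since the real archive is not bundled here):

```python
import os
import tempfile

import h5py
import numpy as np

# build a toy file mimicking the cjdata struct layout (hypothetical values)
path = os.path.join(tempfile.gettempdir(), "toy_cjdata.h5")
with h5py.File(path, "w") as f:
    g = f.create_group("cjdata")
    g["label"] = np.array([[2.0]])                      # 2 = glioma
    g["image"] = np.zeros((512, 512), dtype=np.int16)   # raw slice
    g["tumorMask"] = np.zeros((512, 512), dtype=np.uint8)

# read it back the same way a v7.3 .mat file would be read
with h5py.File(path, "r") as f:
    label = int(f["cjdata"]["label"][0, 0])
    image = np.array(f["cjdata"]["image"])
```

The group/field names match the `cjdata` struct described below; only the values here are made up.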

The data is stored in MATLAB format (.mat files). Each file contains a struct with the following fields for an image:

  • cjdata.label: 1 for meningioma, 2 for glioma, 3 for pituitary tumor
  • cjdata.PID: patient ID
  • cjdata.image: image data
  • cjdata.tumorBorder: a vector storing the coordinates of discrete points on the tumor border,
      e.g. [x1, y1, x2, y2, ...] where (x1, y1) are the planar coordinates of one border point.
      It was generated by manually delineating the tumor border, so it can be used to generate
      a binary tumor mask.
  • cjdata.tumorMask: a binary image with 1s indicating tumor region
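Rasterizing the border into a mask can be done with a generic polygon fill; the sketch below uses `matplotlib.path.Path` on a toy border (this is one possible approach, not the original authors' code, and it assumes the (x, y) point convention stated above):

```python
import numpy as np
from matplotlib.path import Path

def border_to_mask(border, shape=(512, 512)):
    """Rasterize a flat [x1, y1, x2, y2, ...] border into a binary mask."""
    pts = np.asarray(border, dtype=float).reshape(-1, 2)  # (x, y) pairs
    poly = Path(pts)
    # pixel-center grid; rows are y, columns are x
    yy, xx = np.mgrid[:shape[0], :shape[1]]
    inside = poly.contains_points(np.stack([xx.ravel(), yy.ravel()], axis=-1))
    return inside.reshape(shape).astype(np.uint8)

# toy border: a square from (10, 10) to (20, 20)
square = [10, 10, 10, 20, 20, 20, 20, 10]
mask = border_to_mask(square, shape=(32, 32))
```

Exact pixel counts on the polygon boundary depend on the containment test, which is why the real dataset also ships the precomputed cjdata.tumorMask.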

This data was used in the following papers:

  • Cheng, Jun, et al. "Enhanced Performance of Brain Tumor Classification via Tumor Region Augmentation and Partition." PloS one 10.10 (2015).
  • Cheng, Jun, et al. "Retrieval of Brain Tumors by Adaptive Spatial Pooling and Fisher Vector Representation." PloS one 11.6 (2016). MATLAB source code is available on GitHub: https://github.com/chengjun583/brainTumorRetrieval

Imports

In [1]:
from forgebox.imports import *
from forgebox.ftorch.cuda import CudaHandler
from tqdm.notebook import tqdm
import pytorch_lightning as pl
import plotly.express as px
import plotly.graph_objects as go
from ipywidgets import interact

Config & Locations

In [2]:
DATA = Path("/GCI/brain_mri/")
MATS = DATA/"mats"
NUMPYS = DATA/"npy"
In [3]:
WEIGHTS = DATA/"weights"/"region"
WEIGHTS.mkdir(parents=True, exist_ok=True)

Metadata, tabulated records

A metadata pandas DataFrame collects the information about each image

Column descriptions
  • pid: patient id
  • img: location of the image numpy
  • mask: location of the mask numpy, our Y in this case
  • boarder: location of the border-coordinates array (column name spelled as in the CSV); can be ignored when using the mask label
  • label: 1 for meningioma, 2 for glioma, 3 for pituitary tumor
  • shape: shape of the image, here we use only the 512x512
  • img_id: original mat file id of the image
In [4]:
df = pd.read_csv(DATA/"meta.csv")
In [5]:
df["img_id"] = df.img.apply(lambda x:int(Path(x).name.split('.')[0]))
df = df.query("shape=='512_512'").sort_values(by=["img_id"]).reset_index(drop=True)
df.sample(10)
Out[5]:
pid img mask boarder label shape img_id
828 101020 /GCI/brain_mri/mats/829.mat_img.npy /GCI/brain_mri/mats/829.mat_mask.npy /GCI/brain_mri/mats/829.mat_bd.npy 2.0 512_512 829
222 101016 /GCI/brain_mri/mats/223.mat_img.npy /GCI/brain_mri/mats/223.mat_mask.npy /GCI/brain_mri/mats/223.mat_bd.npy 1.0 512_512 223
2729 MR049019G /GCI/brain_mri/mats/2745.mat_img.npy /GCI/brain_mri/mats/2745.mat_mask.npy /GCI/brain_mri/mats/2745.mat_bd.npy 2.0 512_512 2745
1567 97461 /GCI/brain_mri/mats/1583.mat_img.npy /GCI/brain_mri/mats/1583.mat_mask.npy /GCI/brain_mri/mats/1583.mat_bd.npy 3.0 512_512 1583
1636 101145 /GCI/brain_mri/mats/1652.mat_img.npy /GCI/brain_mri/mats/1652.mat_mask.npy /GCI/brain_mri/mats/1652.mat_bd.npy 3.0 512_512 1652
345 110116 /GCI/brain_mri/mats/346.mat_img.npy /GCI/brain_mri/mats/346.mat_mask.npy /GCI/brain_mri/mats/346.mat_bd.npy 1.0 512_512 346
1033 106062 /GCI/brain_mri/mats/1037.mat_img.npy /GCI/brain_mri/mats/1037.mat_mask.npy /GCI/brain_mri/mats/1037.mat_bd.npy 3.0 512_512 1037
2955 MR051796B /GCI/brain_mri/mats/2971.mat_img.npy /GCI/brain_mri/mats/2971.mat_mask.npy /GCI/brain_mri/mats/2971.mat_bd.npy 2.0 512_512 2971
448 100572 /GCI/brain_mri/mats/449.mat_img.npy /GCI/brain_mri/mats/449.mat_mask.npy /GCI/brain_mri/mats/449.mat_bd.npy 1.0 512_512 449
2511 MR017260F /GCI/brain_mri/mats/2527.mat_img.npy /GCI/brain_mri/mats/2527.mat_mask.npy /GCI/brain_mri/mats/2527.mat_bd.npy 2.0 512_512 2527

Interactive visualization

Visualization helpers

In [6]:
def vis_patient(pid):
    sub_df = df.query(f"pid=='{pid}'").sort_values(by="img_id")
    img_arr = np.stack(list(np.load(i) for i in sub_df.img))\
        .astype(np.float32)/1000
    mask_arr = np.stack(list(np.load(i) for i in sub_df["mask"]))\
        .astype(np.float32)
    @interact
    def show_mri(i = (1,len(img_arr))):
        print(list(sub_df.img)[i-1])
        rgb_arr = np.stack([
          mask_arr[i-1],
          np.clip(img_arr[i-1]-mask_arr[i-1],0.,1.),
          img_arr[i-1],                  
        ], axis=-1)

        # rgb_arr = img_arr[i-1].astype(np.float32)
        # print(rgb_arr[200:230,200:230])
        display(plt.imshow(rgb_arr))

Preview image and mask

In [7]:
vis_patient('100360')

Learning

Dataset class

In [8]:
class mri_data(Dataset):
    def __init__(self, df: pd.DataFrame):
        super().__init__()
        self.df = df.reset_index(drop = True)
    
    def __len__(self):
        return len(self.df)

    def __repr__(self):
        return f"MRI Dataset:\n\t{len(self.df.pid.unique())} patients, {len(self)} slices"

    def __getitem__(self,idx):
        row = dict(self.df.loc[idx])
        img = np.load(row["img"])
        img = img/(img.max())
        mask = np.load(row["mask"])
        return img[None, ...], mask[None, ...], row['label']-1

def split_by(
    df: pd.DataFrame,
    col: str,
    val_ratio: float=.2
):
    """
    Split a dataframe into train/validation parts
    by the unique values of a given column,
    so that rows sharing the same value (e.g. the same patient)
    never appear on both sides of the split.

    - col: the column to split on
    - val_ratio: fraction of unique values assigned to validation
    """
    uniques = df[col].unique()
    validation_ids = np.random.choice(
        uniques, size=int(len(uniques)*val_ratio), replace=False)
    val_slice = df[col].isin(validation_ids)
    return df[~val_slice].sample(frac=1.).reset_index(drop=True),\
        df[val_slice].reset_index(drop=True)
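The point of splitting on pid is that slices from one patient never leak across the train/validation boundary. A self-contained sanity check of that logic on a toy frame (not part of the original notebook, restating the same unique-value sampling):

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
# 10 patients, 3 slices each, standing in for the real metadata
toy = pd.DataFrame({"pid": np.repeat(list("abcdefghij"), 3),
                    "img_id": range(30)})

# same logic as split_by: sample whole patients for validation
uniques = toy["pid"].unique()
val_ids = rng.choice(uniques, size=int(len(uniques) * 0.2), replace=False)
val_slice = toy["pid"].isin(val_ids)
toy_train, toy_val = toy[~val_slice], toy[val_slice]

# no patient should appear on both sides
overlap = set(toy_train.pid) & set(toy_val.pid)
```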
In [9]:
train_df, val_df = split_by(df, "pid")
In [10]:
total_ds = mri_data(df)
train_ds = mri_data(train_df)
val_ds = mri_data(val_df)
In [11]:
train_ds, val_ds
Out[11]:
(MRI Dataset:
 	186 patients, 2434 slices,
 MRI Dataset:
 	46 patients, 615 slices)
In [12]:
x,y,z = train_ds[5]
x.shape,y.shape,z
Out[12]:
((1, 512, 512), (1, 512, 512), 1.0)

Mean & Standard Deviation

Mean & standard deviation over the entire dataset; we need them to normalize the network input.
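Averaging per-slice means and stds, as done below, only approximates the true dataset-wide statistics (per-slice stds do not average to the global std). An exact alternative accumulates pixel sums and squared sums; a sketch on toy arrays:

```python
import numpy as np

def global_mean_std(arrays):
    """Exact mean/std over all pixels of an iterable of arrays."""
    n, s, sq = 0, 0.0, 0.0
    for a in arrays:
        a = np.asarray(a, dtype=np.float64)
        n += a.size
        s += a.sum()
        sq += np.square(a).sum()
    mean = s / n
    std = np.sqrt(sq / n - mean ** 2)
    return mean, std

rng = np.random.default_rng(0)
toys = [rng.random((8, 8)) for _ in range(5)]  # stand-ins for MRI slices
mean, std = global_mean_std(toys)
```

For normalization purposes the difference is usually small, which is why the approximation below is acceptable in practice.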

In [13]:
all_x = []
for i in tqdm(range(len(total_ds))):
    x,yy,zz = total_ds[i]
    all_x.append(np.array([x.mean(), x.std()]))

In [14]:
all_arr = np.array(all_x)
x_mean, x_std = all_arr.mean(0)
x_mean, x_std
Out[14]:
(0.15574614257151373, 0.16054656673109854)
In [15]:
all_arr[:,0].min(), all_arr[:,0].max(),all_arr[:,1].min(), all_arr[:,1].max()
Out[15]:
(0.055120470304629875,
 0.29819597127486247,
 0.08168585048503188,
 0.24941559801719104)

Model structure

We are using a segmentation CNN from the segmentation-models-pytorch package

In [16]:
# !pip install -q segmentation-models-pytorch
In [17]:
import segmentation_models_pytorch as smp

Experiments with Unet

In [18]:
model = smp.Unet(
    "efficientnet-b5",
    encoder_weights="imagenet",
    in_channels=1,
    classes=1, 
    )

Test model pipeline

In [19]:
model(torch.FloatTensor(x)[None,...]).shape
Out[19]:
torch.Size([1, 1, 512, 512])

Lightning Data Module

In [20]:
class PlData(pl.LightningDataModule):
    def __init__(self, train_df, val_df, bs):
        super().__init__()
        self.bs = bs
        self.train_df = train_df
        self.val_df = val_df
        self.train_ds = mri_data(self.train_df)
        self.val_ds = mri_data(self.val_df)

    def train_dataloader(self):
        return DataLoader(
            self.train_ds,
            shuffle=True,
            num_workers=8,
            batch_size=self.bs
        )

    def val_dataloader(self):
        """
        validation dataloader
        batch size = train batch size x 2
        """
        return DataLoader(
            self.val_ds,
            shuffle=False,
            num_workers=8,
            batch_size=self.bs * 2
        )

Lightning Module

In [21]:
class PlMRIModel(pl.LightningModule):
    def __init__(self, base_model):
        super().__init__()
        self.base = base_model
        self.sigmoid = nn.Sigmoid()
        self.crit = nn.BCEWithLogitsLoss()
        self.accuracy_f = pl.metrics.Accuracy()
        self.prec = pl.metrics.Precision()
        self.rec = pl.metrics.Recall()

    def forward(self, x):
        return self.base((x - x_mean) / x_std)

    def configure_optimizers(self):
        return torch.optim.AdamW(self.base.parameters(), lr=1e-4)
    
    def norm_x(self, x):
        return (x - x_mean) / x_std
    
    def calc_all_metrics(
        self,
        y_, y, is_train
    ):
        phase = "train" if is_train else "val"
        
        probs = self.sigmoid(y_)  # sigmoid maps logits to probabilities
        acc = self.accuracy_f(probs, y)
        precision = self.prec(probs, y)
        recall = self.rec(probs, y)
        
        self.log(f'{phase}_acc', acc)
        self.log(f'{phase}_prec', precision)
        self.log(f'{phase}_rec', recall)

    def training_step(self, batch, batch_idx):
        x, y, z = batch
        # forward() already normalizes x, so only cast types here
        x = x.float(); y = y.float()
        y_ = self(x)
        loss = self.crit(y_, y)
        
        self.log('train_loss', loss)
        self.calc_all_metrics(y_, y, True)
        
        return loss

    def validation_step(self, batch, batch_idx):
        x, y, z = batch
        x = x.float(); y = y.float()
        y_ = self(x)
        loss = self.crit(y_, y)

        self.log('val_loss', loss)
        self.calc_all_metrics(y_, y, False)
        
        return loss
In [22]:
pl_data = PlData(train_df, val_df, bs=8)
pl_model = PlMRIModel(model)
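The LightningModule above logs pixel accuracy, precision, and recall; for segmentation masks a per-slice Dice overlap is a common complementary metric (not computed in the original notebook). A minimal numpy sketch:

```python
import numpy as np

def dice_score(pred_mask, true_mask, eps=1e-7):
    """Dice overlap 2|A∩B| / (|A| + |B|) for two binary masks."""
    pred = np.asarray(pred_mask, dtype=bool)
    true = np.asarray(true_mask, dtype=bool)
    inter = np.logical_and(pred, true).sum()
    return (2.0 * inter + eps) / (pred.sum() + true.sum() + eps)

# toy masks: two 4x4 squares with a 2x2 overlap
a = np.zeros((8, 8), dtype=np.uint8); a[2:6, 2:6] = 1
b = np.zeros((8, 8), dtype=np.uint8); b[4:8, 4:8] = 1
```

Unlike pixel accuracy, Dice is insensitive to the large tumor-free background, which dominates 512x512 slices.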

Training configuration

Logging and callbacks

In [24]:
# loggers
logger = pl.loggers.TensorBoardLogger("/GCI/tensorboard/brain_mri/", name="region")

# callbacks
early = pl.callbacks.EarlyStopping(monitor="val_acc")
saving = pl.callbacks.ModelCheckpoint(str(WEIGHTS), monitor="val_acc", save_top_k = 5, mode="max")
/anaconda3/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:49: UserWarning: Checkpoint directory /GCI/brain_mri/weights/region exists and is not empty.
  warnings.warn(*args, **kwargs)
In [25]:
cu = CudaHandler()
dev = cu.idle().device
>>> 2 cuda devices found >>>
Device 0: 
	name:Tesla V100-PCIE-32GB
	used:4MB	free:32506MB
Device 1: 
	name:Tesla V100-PCIE-32GB
	used:4MB	free:32506MB
cuda stats refreshed
Found the most idle GPU: cuda:0, 32506 MB Mem remained
In [26]:
trainer = pl.Trainer(
    logger=logger,
    callbacks=[early, saving],
    checkpoint_callback=True,
    gpus=[dev.index],
    fast_dev_run=False,
)
GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

A progress bar to watch

In [27]:
trainer.fit(pl_model,pl_data)
  | Name       | Type              | Params
-------------------------------------------------
0 | base       | Unet              | 31.2 M
1 | sigmoid    | Sigmoid           | 0     
2 | crit       | BCEWithLogitsLoss | 0     
3 | accuracy_f | Accuracy          | 0     
4 | prec       | Precision         | 0     
5 | rec        | Recall            | 0     
-------------------------------------------------
31.2 M    Trainable params
0         Non-trainable params
31.2 M    Total params

Out[27]:
1

Visualize Prediction

Switch to evaluation mode

In [29]:
WEIGHTS.ls()
Out[29]:
['epoch=8-step=2807.ckpt',
 'epoch=11-step=3659.ckpt',
 'epoch=9-step=3049.ckpt',
 'epoch=6-step=2134.ckpt',
 'epoch=10-step=3354.ckpt',
 'epoch=8-step=2744.ckpt']
In [44]:
pl_model = PlMRIModel.load_from_checkpoint(
    WEIGHTS/'epoch=8-step=2807.ckpt', base_model = model)
In [45]:
pl_model = pl_model.eval()

An array to array pipeline

In [87]:
def pred(x: np.ndarray) -> np.ndarray:
    """
    predict a binary mask array from an image array
    """
    with torch.no_grad():
        prob = pl_model.sigmoid(pl_model(torch.FloatTensor(x).cuda(0)))
        return prob.cpu().numpy()[0, 0] > .5

View the validation prediction interactively

  • Red, prediction
  • Green, label
  • Blue, input
In [88]:
def see_val(idx):
    x, y, z = val_ds[idx]
    y_ = pred(x[None, :])
    img_arr = np.stack(
        [y_ * 180, y[0] * 180, x[0] * 255],
        axis=-1)
    return Image.fromarray(
        img_arr.astype(np.uint8),  # uint8, not signed byte, for PIL RGB
        mode="RGB")
In [89]:
pl_model = pl_model.cuda(0)
imgs = list(see_val(i) for i in tqdm(range(len(val_ds))))
pl_model = pl_model.cpu()

In [90]:
from forgebox.images.widgets import view_images
In [93]:
view_images(*imgs)()
In [ ]: